# Supplementary code for "Whisfusion: Parallel ASR Decoding via a Diffusion Transformer"

This document provides instructions for setting up the environment, preparing the data, and running the training and evaluation scripts to reproduce the main results presented in the paper.

<p align="center">
  <img src="assets/inference.gif" width="80%">
</p>

## 1. Setup

### 1.1. Requirements

This code was tested using Python 3.10 and CUDA 12.1.

#### Install PyTorch with CUDA 12.1
```bash
pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
```

#### Install FlashAttention from source
```bash
git clone --recurse-submodules --branch v2.6.3 https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install .
cd csrc/rotary && pip install .
cd ../layer_norm && pip install .
cd ../xentropy && pip install .
cd ../../.. && rm -rf flash-attention
```

#### Install other dependencies
```bash
pip install -r requirements.txt
```
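
Before moving on, it can help to confirm that the packages installed above are importable. The following is a minimal sketch, not part of the repository: the helper name `missing_modules` is hypothetical, and the module list covers only the packages named in this section (`flash_attn` is the import name of FlashAttention).

```python
import importlib.util
import sys

# Import names of the packages installed in this section.
REQUIRED_MODULES = ["torch", "torchvision", "torchaudio", "flash_attn"]


def missing_modules(names):
    """Return the subset of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print(f"Python {sys.version_info.major}.{sys.version_info.minor}")
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        import torch

        # torch.version.cuda is the CUDA version PyTorch was built against.
        print("torch", torch.__version__, "| CUDA build:", torch.version.cuda)
        print("CUDA available:", torch.cuda.is_available())
```

If any module is reported missing, re-run the corresponding install command before continuing.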

## 2. Reproducing Main Results

The following steps outline the full process from data download to final evaluation.

### Step 1: Download Datasets

The following script downloads and extracts all necessary LibriSpeech partitions into the `data/raw/` directory.

**Note:** The full dataset is several dozen gigabytes, so the download may take a long time.

```bash
chmod +x scripts/00_download_librispeech.sh
./scripts/00_download_librispeech.sh
```

### Step 2: Preprocess Audio Data

This script converts the raw audio files into Whisper encoder hidden states, which are saved as `.pt` files in `data/processed/`.

```bash
chmod +x scripts/01_preprocess_librispeech.sh
./scripts/01_preprocess_librispeech.sh
```
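
A quick sanity check after this step is to count the files on both sides. The sketch below is not part of the repository; `count_files` is a hypothetical helper, and it assumes only that the raw LibriSpeech audio is stored as `.flac` files and that features are written as `.pt` files, as described above. Whether the two counts should match exactly depends on how the preprocessing script groups utterances, which is not specified here.

```python
from pathlib import Path


def count_files(root, suffix):
    """Count files under `root` (recursively) whose name ends with `suffix`."""
    root = Path(root)
    if not root.is_dir():
        return 0
    return sum(1 for p in root.rglob(f"*{suffix}") if p.is_file())


if __name__ == "__main__":
    n_audio = count_files("data/raw", ".flac")
    n_feats = count_files("data/processed", ".pt")
    print(f"raw .flac files:     {n_audio}")
    print(f"processed .pt files: {n_feats}")
```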

### Step 3: Download Pre-trained Models

This script downloads the pre-trained MDM decoder weights required for training.

```bash
chmod +x scripts/02_download_pretrained_model.sh
./scripts/02_download_pretrained_model.sh
```

### Step 4: Model Training

The Whisfusion model is trained using a two-stage curriculum.

#### Stage 1: Adapter Training

This stage trains only the cross-attention adapter.

```bash
chmod +x scripts/03_train_stage1_adapter.sh
./scripts/03_train_stage1_adapter.sh
```
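
Training only one component usually means disabling gradients for everything else. The sketch below illustrates that general idea and is not the repository's implementation: `freeze_all_but` is a hypothetical helper, and the keyword used to identify adapter parameters (for example `"cross_attn"`) is an assumption that depends on how the model names its modules.

```python
def freeze_all_but(model, keyword):
    """Enable gradients only for parameters whose name contains `keyword`.

    Works with any object exposing `named_parameters()` that yields
    (name, parameter) pairs, such as a `torch.nn.Module`.
    Returns the names of the parameters left trainable.
    """
    trainable = []
    for name, param in model.named_parameters():
        # True for matching parameters, False (frozen) for all others.
        param.requires_grad = keyword in name
        if param.requires_grad:
            trainable.append(name)
    return trainable
```

Calling `freeze_all_but(model, "cross_attn")` and printing the returned names is a simple way to confirm that only the intended parameters remain trainable.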

#### Stage 2: Full Decoder Fine-tuning

This stage fine-tunes the entire decoder and the adapter.

**Important:** Before running, you must update the `--pretrain_path` argument in `scripts/04_train_stage2_decoder_high_ratio.sh` to point to the `adapter_best.pt` file generated in Stage 1.

```bash
chmod +x scripts/04_train_stage2_decoder_high_ratio.sh
./scripts/04_train_stage2_decoder_high_ratio.sh
```

The final model, `model_best.pt`, is saved in the `out/stage2_.../` directory.
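
Stage 2 needs the path to `adapter_best.pt`, and the evaluation step needs the path to `model_best.pt`. A small helper can locate the most recent copy of either file. This is a sketch, not part of the repository: `latest_checkpoint` is a hypothetical name, and it assumes that Stage 1 also writes its checkpoints somewhere under `out/`.

```python
from pathlib import Path


def latest_checkpoint(root, filename):
    """Return the most recently modified file called `filename` under `root`,
    or None if `root` does not exist or contains no such file."""
    root = Path(root)
    if not root.is_dir():
        return None
    candidates = [p for p in root.rglob(filename) if p.is_file()]
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


if __name__ == "__main__":
    print("Stage 1 adapter:", latest_checkpoint("out", "adapter_best.pt"))
    print("Stage 2 model:  ", latest_checkpoint("out", "model_best.pt"))
```

The printed paths can then be copied into the corresponding script arguments.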

### Step 5: Evaluation

These scripts reproduce the performance metrics reported in the paper.
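
This document does not list the metrics themselves. ASR accuracy is conventionally reported as word error rate (WER), so the sketch below shows the standard word-level edit-distance definition for reference. It is an illustration only; the evaluation scripts may apply text normalization or use a library implementation, and `word_error_rate` is a hypothetical name.

```python
def word_error_rate(reference, hypothesis):
    """Word-level edit distance divided by the number of reference words."""
    ref = reference.split()
    hyp = hypothesis.split()
    if not ref:
        raise ValueError("reference must contain at least one word")
    # prev[j] holds the edit distance between the reference words seen so far
    # and the first j hypothesis words.
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i]
        for j, h in enumerate(hyp, start=1):
            cost = 0 if r == h else 1
            curr.append(min(
                prev[j] + 1,         # deletion of a reference word
                curr[j - 1] + 1,     # insertion of a hypothesis word
                prev[j - 1] + cost,  # substitution or exact match
            ))
        prev = curr
    return prev[-1] / len(ref)
```

For example, `word_error_rate("the cat sat", "the bat sat")` counts one substitution against three reference words.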

#### Evaluating Whisfusion

**Important:** Before running, you must update the `--adapter_path` argument in `scripts/05_evaluate_whisfusion.sh` to the path of the final model (`model_best.pt`) generated in Stage 2.

```bash
chmod +x scripts/05_evaluate_whisfusion.sh
./scripts/05_evaluate_whisfusion.sh
```

#### Evaluating Whisper Baselines

This script measures the performance of the Whisper baseline models. To evaluate different variants, edit the list of models inside the script.

```bash
chmod +x scripts/06_evaluate_whisper.sh
./scripts/06_evaluate_whisper.sh
```

All evaluation results will be saved as JSON files in the `evaluation_results/` directory.
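
To inspect the saved results, the JSON files can be loaded in a few lines. The sketch below is not part of the repository: `load_results` is a hypothetical helper, and because the JSON schema is not documented here, it prints only the top-level keys of each file.

```python
import json
from pathlib import Path


def load_results(results_dir):
    """Load every `*.json` file in `results_dir` into a dict keyed by file name."""
    results_dir = Path(results_dir)
    results = {}
    if not results_dir.is_dir():
        return results
    for path in sorted(results_dir.glob("*.json")):
        with path.open(encoding="utf-8") as f:
            results[path.name] = json.load(f)
    return results


if __name__ == "__main__":
    for name, content in load_results("evaluation_results").items():
        # Show only the top-level shape, since the schema is not specified.
        shape = sorted(content) if isinstance(content, dict) else type(content).__name__
        print(name, "->", shape)
```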

## 3. Pre-trained Model Weights

Due to the double-blind review policy and submission size limits, we cannot provide the final trained model weights in this appendix. We are committed to full reproducibility and will make the pre-trained weights of our final model publicly available upon publication of the paper.

## 4. Project Structure

```
Whisfusion/
├── scripts/                   # All executable scripts
├── src/
│   ├── data/                  # Data preprocessing modules
│   ├── training/              # Training scripts
│   ├── evaluation/            # Evaluation scripts
│   └── lit_gpt/               # Model architecture
├── data/
│   ├── raw/                   # Downloaded LibriSpeech data
│   └── processed/             # Preprocessed Whisper features
├── sample_data/               # Sample data for quick testing
│   ├── raw/                   # Sample raw audio files
│   └── processed/             # Sample preprocessed features
├── pretrained_models/         # Pre-trained model weights
└── out/                       # Training outputs and checkpoints
```

## 5. Hardware Requirements

- **Training**: 4× NVIDIA A100 (40GB VRAM) GPUs were used to train the model with the batch sizes specified in the scripts.
- **Evaluation**: 1× NVIDIA A100 GPU was used for all reported latency measurements.
- **Storage**: ~700GB of free space is required to store the raw dataset and the preprocessed hidden-state features.
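
Most of the storage goes to the cached encoder hidden states, and a back-of-the-envelope estimate shows why. The sketch below is illustrative only: `feature_bytes` is a hypothetical helper, and every input value is an assumption rather than a setting taken from the scripts (roughly 280k training utterances, 1500 encoder frames per 30-second Whisper window, a hidden size of 768, and 2-byte float16 values).

```python
def feature_bytes(num_utterances, frames, hidden_dim, bytes_per_value):
    """Bytes needed to store one (frames x hidden_dim) tensor per utterance,
    ignoring file-format overhead."""
    return num_utterances * frames * hidden_dim * bytes_per_value


if __name__ == "__main__":
    # All values below are assumptions for illustration, not repository settings.
    estimate = feature_bytes(
        num_utterances=280_000,  # assumed: rough size of the full training set
        frames=1500,             # Whisper encoder output length for a 30 s window
        hidden_dim=768,          # assumed encoder width
        bytes_per_value=2,       # assumed float16 storage
    )
    print(f"{estimate / 1e9:.0f} GB")
```

Doubling `bytes_per_value` (float32) or using a wider encoder scales the estimate proportionally, so check the actual settings in the preprocessing script before provisioning disk space.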